classdef signal_sep
    % SIGNAL_SEP Separate behavior-related components from inferred spike /
    % fluorescence traces.
    %   Stores two sampling rates (Fsb, Fsi). The static methods convert index
    %   vectors between the two rates, deconvolve traces, and redistribute
    %   spikes across behavior modes.
    %   NOTE(review): Fsb / Fsi presumably are the behavior and imaging
    %   sampling rates - confirm with callers.

    properties
        Fsb;  % rate of the index space converted FROM (see index_convert)
        Fsi;  % rate of the index space converted TO; time base of plots
    end

    methods
        function ss = signal_sep(fsb, fsi)
            % Construct the object holding the two sampling rates.
            ss.Fsb = fsb;
            ss.Fsi = fsi;
        end
    end

    methods (Static)
        %% general utility %%
        function behav = idx2end(ids, Fsb, Fsi)
            % IDX2END First/last index of each labeled run in IDS, rescaled by Fsi/Fsb.
            %   behav is 2-by-n: row 1 holds the rescaled first index and row 2
            %   the rescaled last index of run i (runs labeled by bwlabeln).
            %   Rescaled indices are clamped to >= 1 so they stay valid 1-based
            %   indices (round(1 * Fsi / Fsb) is 0 when Fsi/Fsb < 0.5).
            [lab, n] = bwlabeln(ids);
            behav = zeros(2, n);
            for i = 1:n
                seg = lab == i;
                behav(1, i) = max(1, round(find(seg, 1) * Fsi / Fsb));
                behav(2, i) = max(1, round(find(seg, 1, 'last') * Fsi / Fsb));
            end
        end

        %% signal separation %%
        function dataout = spk_data_prep(datain)
            % SPK_DATA_PREP Deconvolve every row of datain{1} with constrained_foopsi.
            %   dataout = {spks, gs, bs, spkthres, datao}
            %     spks     - inferred spike train per row (same size as datain{1})
            %     gs       - options.p-by-nrow coefficients, one column per row
            %                (used by extract_factor as AR coefficients)
            %     bs       - nrow-by-1 baseline returned by constrained_foopsi
            %     spkthres - nrow-by-1 threshold nanmean(spk) + 2 * mad(spk)
            %     datao    - denoised trace (cc + baseline) per row
            dff = datain{1};
            options.p = 2;  % AR order; also fixes the row count of gs

            %%% then prepare spk relevant variables %%%
            nrow = size(dff, 1);
            gs = zeros(options.p, nrow);
            bs = zeros(nrow, 1);
            spkthres = zeros(nrow, 1);
            spks = dff;
            datao = dff;
            parfor ii = 1:nrow
                [cc, cb, ~, gn, ~, spk] = constrained_foopsi(dff(ii, :), [], 0, [], [], options);
                spks(ii, :) = spk;
                gs(:, ii) = gn;
                bs(ii) = cb;
                spkthres(ii) = nanmean(spk) + 2 * mad(spk);
                datao(ii, :) = cc + cb;
            end

            dataout = {spks, gs, bs, spkthres, datao};
        end

        function dataout = extract_factor(ss, datain, fflag)
            % EXTRACT_FACTOR Keep, per behavior mode, only the spike mass in excess
            % of the matching mode with the extracted behavior switched off, then
            % re-integrate fluorescence from the retained spikes.
            %   Assumptions: 1. minimal signal during equalization
            %                2. constraint: all signals sum up to the whole
            %   datain = {spks, gs, bs, idbs}
            %     spks - nrow-by-T spikes
            %     gs   - AR coefficients, one column per row of spks
            %     bs   - per-row baseline added after integration
            %     idbs - T-by-nbehav behavior indicator matrix (rows = time points)
            %   fflag  - which behavior to extract (see combine_idx)
            %   dataout = {spksfr, datafr}: spikes averaged over nr random draws
            %   and the fluorescence re-integrated from them.
            spks = datain{1};
            gs = datain{2};
            bs = datain{3};
            idbs = datain{4};

            nr = 100;       % number of random resampling replicates
            nedge = 101;    % histogram edges
            nbin = nedge - 1;

            %%% 1: parse behavior combination %%%
            idxall = ss.combine_idx(idbs, fflag);
            idpool = unique(idxall, 'rows');
            % modes in which the extracted behavior (now column 1) is on
            idwppool = idpool(idpool(:, 1) == 1, :);
            ntime = size(idxall, 1);

            %%% 2: get weighted signal distribution %%%
            nn = size(spks, 1);
            nid = size(idwppool, 1);
            nidraw = size(idpool, 1);
            mx = max(spks(:));
            if isempty(mx) || ~(mx > 0)
                % empty / all-zero / all-NaN input: linspace(0, mx, ...) would not
                % give increasing edges, which histcounts rejects
                mx = 1;
            end
            edges = linspace(0, mx, nedge);
            dist = NaN(nid, nbin);
            distraw = NaN(nidraw, nbin);

            % reference all appearing modes %
            weight = zeros(1, nidraw);
            tpall = cell(nidraw, 1);
            for ii = 1:nidraw
                inmode = ismember(idxall, idpool(ii, :), 'rows');
                vals = spks(:, inmode(:));
                tpall{ii} = vals(:);
                weight(ii) = sum(inmode);
                h = histcounts(tpall{ii}, edges);
                distraw(ii, :) = h / sum(h);
            end

            % start separate per mode %
            idl = false(nid, ntime);
            idtodo = [];
            for ii = 1:nid
                onrow = idwppool(ii, :);
                offrow = onrow;
                offrow(1) = false;  % same mode with the extracted behavior off
                lon = ismember(idxall, onrow, 'rows');
                loff = ismember(idxall, offrow, 'rows');
                tpon = spks(:, lon(:));
                tpoff = spks(:, loff(:));
                idl(ii, :) = lon';

                if ~isempty(tpoff) && ~isempty(tpon)
                    h1 = histcounts(tpon(:), edges);
                    h2 = histcounts(tpoff(:), edges);
                    hd = ss.dist_equalize(h1 / sum(h1), h2 / sum(h2), sum(h1));
                    dist(ii, :) = ceil(hd);
                else
                    idtodo(end + 1) = ii; %#ok<AGROW>
                end
            end

            %%% estimate averaged existing distribution %%%
            % The "off" counterpart of these modes never occurs; approximate its
            % distribution by a least-squares, occurrence-weighted combination of
            % the distributions of all observed modes.
            for ii = idtodo
                target = double(idwppool(ii, :)');
                target(1) = 0;
                x = double(idpool') \ target;
                if all(x == 0)
                    x = ones(size(x)) / numel(x);
                end
                distuse = (distraw' * (x(:) .* weight(:)))';
                distuse = distuse / sum(distuse);
                selfrow = ismember(idpool, idwppool(ii, :), 'rows');
                h1 = histcounts(tpall{selfrow}, edges);
                hd = ss.dist_equalize(h1 / sum(h1), distuse, sum(h1));
                dist(ii, :) = ceil(hd);
            end

            %%% 3: random pick signal %%%
            % discretize uses the same bin convention as histcounts (left edge
            % included, last bin closed), so the candidates in bin jj are exactly
            % the values counted in dist(:, jj).
            binid = discretize(spks, edges);
            hits = zeros(size(spks));  % how often each entry was picked over nr draws
            for ii = 1:nid
                for jj = 1:nbin
                    % idl(ii, :) (1-by-T) expands over the rows of binid (nn-by-T)
                    cand = find(idl(ii, :) & binid == jj);
                    k = min(numel(cand), dist(ii, jj));
                    if k < 1
                        continue
                    end
                    for kk = 1:nr
                        % candidates of different (ii, jj) are disjoint and
                        % randsample draws without replacement, so no index
                        % repeats within one assignment
                        pick = cand(randsample(numel(cand), k));
                        hits(pick) = hits(pick) + 1;
                    end
                end
            end

            %%% 4: get updated spk info %%%
            % mean over the nr replicates of spks .* picked == spks .* (hits / nr);
            % avoids materializing an nn-by-T-by-nr array
            spksfr = spks .* (hits / nr);

            %%% 5: get updated fluorescence %%%
            datafr = spksfr;
            for ii = 1:nn
                datafr(ii, :) = ss.ar_integrate(spksfr(ii, :), gs(:, ii)) + bs(ii);
            end

            dataout = {spksfr, datafr};
        end

        function idxall = combine_idx(idbs, fflag)
            % COMBINE_IDX Rotate the columns of idbs so the behavior named by fflag
            % comes first (remaining columns keep their cyclic order).
            %   Column order of idbs: stim, wiping, paw motor (prin + sec),
            %   whisker motor, jaw motor, i.e. 1 stim, 2 wipe, 3 paw prin,
            %   4 paw sec, 5 whisker, 6 jaw.
            %   Throws signal_sep:combine_idx:unknownFlag for an unknown fflag.
            n = size(idbs, 2);
            switch fflag
                case 'extract_stim',  first = 1;
                case 'extract_wipe',  first = 2;
                case 'extract_pawp',  first = 3;
                case 'extract_paws',  first = 4;
                case 'extract_whisk', first = 5;
                case 'extract_jaw',   first = 6;
                otherwise
                    error('signal_sep:combine_idx:unknownFlag', ...
                        'Unknown fflag ''%s''.', fflag);
            end

            idxall = idbs(:, [first:n, 1:first - 1]);
        end

        function hd = dist_equalize(h1n, h2n, n1)
            % DIST_EQUALIZE Per-bin counts to peel off h1 so it approaches h2.
            %   Each step removes rt (10%) of the positive excess max(0, h1n - h2n),
            %   scaled by the remaining count nn, accumulates it in hd, and
            %   renormalizes h1n. Stops when every bin differs by < 1/n1, when
            %   less than half of n1 remains, when the Jaccard overlap between the
            %   current and initial sets of bins with h1n > h2n drops below 0.9999,
            %   or when nothing is left to remove.
            %   h1n, h2n - normalized histograms of equal size
            %   n1       - total count behind h1n
            %   hd       - fractional per-bin counts (callers in this class ceil it)
            rt = 0.1;
            thres = 1 / n1;
            nn = n1;
            hd = zeros(size(h1n));
            idreal = h1n > h2n;
            while true
                d = abs(h1n - h2n);
                idcurr = h1n > h2n;
                ratio = sum(idreal & idcurr) / sum(idreal | idcurr);
                if all(d < thres) || nn < n1 / 2 || ratio < 0.9999
                    break
                end
                dt = rt * max(0, h1n - h2n);
                if ~any(dt > 0)
                    % no excess to remove (e.g. NaN histograms, where ratio is
                    % 0/0): nothing would change and the loop would never exit
                    break
                end
                hd = hd + dt * nn;
                nn = nn * (1 - sum(dt));
                h1n = h1n - dt;
                h1n = h1n / sum(h1n);
            end
        end

        function aa = ar_integrate(t, gn)
            % AR_INTEGRATE Apply the AR recursion aa(i) = sum_k gn(k) * aa(i-k) + t(i).
            %   t  - input vector (e.g. a spike train); aa has the same shape
            %   gn - AR coefficients, row or column; empty gn returns t unchanged
            aa = t;
            gn = gn(:);
            if isempty(gn)
                return
            end
            past = zeros(numel(gn), 1);  % past outputs, most recent first
            for i = 1:numel(t)
                aa(i) = past' * gn + t(i);
                past = [aa(i); past(1:end - 1)];
            end
        end

        function idout = index_convert(idin, Fsb, Fsi)
            % INDEX_CONVERT Resample a logical index matrix column by column.
            %   Each run of true values in idin(:, c) (labeled by bwlabel) spanning
            %   samples [a, b] becomes [round(a*Fsi/Fsb), round(b*Fsi/Fsb)] in
            %   idout(:, c), which has round(size(idin, 1) * Fsi / Fsb) rows.
            %   Mapped indices are clamped to >= 1: a run starting at sample 1
            %   would otherwise map to index 0 when Fsi/Fsb < 0.5 and error.
            ncol = size(idin, 2);
            idout = false(round(size(idin, 1) * Fsi / Fsb), ncol);
            for c = 1:ncol
                [lab, nrun] = bwlabel(idin(:, c));
                for j = 1:nrun
                    seg = find(lab == j);
                    stt = max(1, round(seg(1) * Fsi / Fsb));
                    stp = max(1, round(seg(end) * Fsi / Fsb));
                    idout(stt:stp, c) = true;
                end
            end
        end

        %% visualization %%
        function plot_sig_behav(ss, dffs, behav, ttl)
            % PLOT_SIG_BEHAV Plot stacked traces with shaded behavior epochs.
            %   dffs  - one trace per row, sampled at ss.Fsi; row r is scaled by
            %           10 and offset by r
            %   behav - 2-by-n [start; stop] sample indices (e.g. from idx2end),
            %           or 1-by-n single samples (drawn as zero-width epochs)
            %   ttl   - plot title
            Fsi = ss.Fsi;
            figure(gcf)
            clf
            hold on
            nrow = size(dffs, 1);
            plot((1:size(dffs, 2)) / Fsi, (10 * dffs + (1:nrow)')', 'k')
            axis tight
            rgy = get(gca, 'ylim');
            xlabel('Time [s]')
            ylabel('Neuron #')
            title(ttl)
            if size(behav, 1) == 1
                behav = repmat(behav, 2, 1);
            end
            for i = 1:size(behav, 2)
                xs = [behav(:, i); flipud(behav(:, i))] / Fsi;
                patch(xs, [rgy(1), rgy(1), rgy(2), rgy(2)], 'b', ...
                    'edgecolor', 'none', 'FaceAlpha', 0.3)
            end
        end
    end
end
